{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 代码生成：extern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import te\n",
    "import numpy as np\n",
    "import tvm.testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## add_pipeline extern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# from tvm.script import ir as I\n",
      "# from tvm.script import tir as T\n",
      "\n",
      "@I.ir_module\n",
      "class Module:\n",
      "    @T.prim_func\n",
      "    def main(A: T.Buffer((64,), \"float32\"), C: T.Buffer((64,), \"float32\")):\n",
      "        T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n",
      "        T.attr(0, \"extern_scope\", 0)\n",
      "        for i in range(32):\n",
      "            cse_var_1: T.int32 = i * 2\n",
      "            C[cse_var_1:cse_var_1 + 2] = A[cse_var_1:cse_var_1 + 2] + T.Broadcast(T.float32(1), 2)\n",
      "# from tvm.script import ir as I\n",
      "# from tvm.script import tir as T\n",
      "\n",
      "@I.ir_module\n",
      "class Module:\n",
      "    @T.prim_func\n",
      "    def main(A: T.Buffer((64,), \"float32\"), C: T.Buffer((64,), \"float32\")):\n",
      "        T.func_attr({\"from_legacy_te_schedule\": T.bool(True), \"tir.noalias\": T.bool(True)})\n",
      "        T.attr(0, \"extern_scope\", 0)\n",
      "        blockIdx_x = T.launch_thread(\"blockIdx.x\", 16)\n",
      "        threadIdx_x = T.launch_thread(\"threadIdx.x\", 4)\n",
      "        C[blockIdx_x * 8 + threadIdx_x * 2:blockIdx_x * 8 + threadIdx_x * 2 + 2] = A[blockIdx_x * 8 + threadIdx_x * 2:blockIdx_x * 8 + threadIdx_x * 2 + 2] + T.Broadcast(T.float32(1), 2)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[08:59:06] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n",
      "[08:59:06] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n",
      "[08:59:06] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "nn = 64\n",
    "max_threads = 4\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "\n",
    "def extern_generator(ins, outs):\n",
    "    \"\"\"Manually write the IR for the extern function, add pipeline\"\"\"\n",
    "    ib = tvm.tir.ir_builder.create()\n",
    "    with ib.for_range(0, (n + 1) // 2) as i:\n",
    "        ib.emit(\n",
    "            outs[0].vstore(\n",
    "                i * 2, ins[0].vload(i * 2, \"float32x2\") + tvm.tir.const(1, \"float32x2\")\n",
    "            )\n",
    "        )\n",
    "    return ib.get()\n",
    "\n",
    "def extern_generator_gpu(ins, outs):\n",
    "    \"\"\"Manually write the IR for the extern function, add pipeline\"\"\"\n",
    "    ib = tvm.tir.ir_builder.create()\n",
    "    bx = te.thread_axis(\"blockIdx.x\")\n",
    "    tx = te.thread_axis(\"threadIdx.x\")\n",
    "    ib.scope_attr(bx, \"thread_extent\", (nn + max_threads - 1) // max_threads)\n",
    "    ib.scope_attr(tx, \"thread_extent\", max_threads)\n",
    "    idx = bx.var * max_threads + tx.var\n",
    "    with ib.if_scope(ib.likely(idx < n)):\n",
    "        ib.emit(\n",
    "            outs[0].vstore(\n",
    "                idx * 2, ins[0].vload(idx * 2, \"float32x2\") + tvm.tir.const(1, \"float32x2\")\n",
    "            )\n",
    "        )\n",
    "    return ib.get()\n",
    "\n",
    "C_cpu = te.extern(A.shape, [A], extern_generator, name=\"C\")\n",
    "C_gpu = te.extern(A.shape, [A], extern_generator_gpu, name=\"C\")\n",
    "s_cpu = te.create_schedule(C_cpu.op)\n",
    "s_gpu = te.create_schedule(C_gpu.op)\n",
    "print(tvm.lower(s_cpu, [A, C_cpu], simple_mode=True))\n",
    "print(tvm.lower(s_gpu, [A, C_gpu], simple_mode=True))\n",
    "\n",
    "def check_target(target):\n",
    "    if not tvm.testing.device_enabled(target):\n",
    "        return\n",
    "    s = s_gpu if target in [\"opencl\", \"cuda\"] else s_cpu\n",
    "    C = C_gpu if target in [\"opencl\", \"cuda\"] else C_cpu\n",
    "    # build and invoke the kernel.\n",
    "    f = tvm.build(s, [A, C], target)\n",
    "    dev = tvm.device(target, 0)\n",
    "    # launch the kernel.\n",
    "    n = nn\n",
    "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "    f(a, c)\n",
    "    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)\n",
    "\n",
    "check_target(\"llvm\")\n",
    "check_target(\"opencl\")\n",
    "check_target(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简单 pack buffer extern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09:00:33] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n",
      "[09:00:33] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "\n",
    "def extern_generator(ins, outs):\n",
    "    \"\"\"Manually write the IR for the extern function, add pipeline.\"\"\"\n",
    "    return tvm.tir.call_packed(\"my_extern_array_func1\", ins[0], outs[0])\n",
    "\n",
    "C = te.extern(A.shape, [A], extern_generator, name=\"C\")\n",
    "s = te.create_schedule(C.op)\n",
    "\n",
    "@tvm.register_func\n",
    "def my_extern_array_func1(aa, bb):\n",
    "    aa.copyto(bb)\n",
    "\n",
    "def check_target(target):\n",
    "    if not tvm.testing.device_enabled(target):\n",
    "        return\n",
    "    # build and invoke the kernel.\n",
    "    f = tvm.build(s, [A, C], target)\n",
    "    dev = tvm.cpu(0)\n",
    "    # launch the kernel.\n",
    "    n = nn\n",
    "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "\n",
    "    f(a, c)\n",
    "    tvm.testing.assert_allclose(c.numpy(), a.numpy())\n",
    "\n",
    "check_target(\"stackvm\")\n",
    "check_target(\"llvm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## pack_buffer_intermediate extern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09:02:12] /media/pc/data/lxw/ai/tvm/src/target/target_kind.cc:158: Warning: Unable to detect CUDA version, default to \"-arch=sm_50\" instead\n"
     ]
    }
   ],
   "source": [
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.compute((n,), lambda i: A[i] + 1, name=\"B\")\n",
    "\n",
    "def extern_generator(ins, outs):\n",
    "    \"\"\"Manually write the IR for the extern function, add pipeline.\"\"\"\n",
    "    return tvm.tir.call_packed(\"my_extern_array_func2\", ins[0], outs[0])\n",
    "\n",
    "C = te.extern(B.shape, [B], extern_generator, name=\"C\")\n",
    "s = te.create_schedule(C.op)\n",
    "\n",
    "def check_target(target):\n",
    "    if not tvm.testing.device_enabled(target):\n",
    "        return\n",
    "    # build and invoke the kernel.\n",
    "    f = tvm.build(s, [A, C], target)\n",
    "    dev = tvm.cpu(0)\n",
    "    # launch the kernel.\n",
    "    n = nn\n",
    "    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "\n",
    "    @tvm.register_func\n",
    "    def my_extern_array_func2(aa, bb):\n",
    "        assert aa.shape == a.shape\n",
    "        tvm.testing.assert_allclose(aa.numpy(), a.numpy() + 1)\n",
    "        aa.copyto(bb)\n",
    "\n",
    "    f(a, c)\n",
    "    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)\n",
    "\n",
    "check_target(\"llvm\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aix",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
